from ..layers.layer_utils import Conv2d_tucker, Conv2d_tucker_adaptive, Conv2d_tucker_fixed
from ..layers.linear_lr_new import DLRTLinearAdaptive  #### non tensorial layer
import torch
import torch.distributed as dist
import torch.nn as nn
from torchvision.models import resnet
from rich.console import Console
import tntorch
from einops import rearrange

import warnings

# Process-wide side effect at import time: install an "ignore" filter for
# every warning category, from any module, for the rest of the process.
# NOTE(review): this also hides useful deprecation/numerical warnings from
# torch and other libraries; consider narrowing the filter.
warnings.filterwarnings("ignore")
# Make float32 the global default floating-point dtype for tensors created
# without an explicit dtype (also a process-wide side effect).
torch.set_default_dtype(torch.float32)

# Module-level rich Console with a fixed 140-column width (not referenced in
# the code visible here).
console = Console(width=140)


class DLRTNetwork(nn.Module):
    """Wrap a PyTorch network so it can be optimized with the DLRT optimizer.

    ``nn.Conv2d`` layers are replaced by Tucker-factorized DLRT layers
    (``Conv2d_tucker``, when ``tucker`` is set) and ``nn.Linear`` layers by
    ``DLRTLinearAdaptive`` (when ``matrix_dlrt`` is set). The wrapper also
    broadcasts DLRT commands (``set_step``, ``K_preprocess_step``, ...) to every
    layer carrying a ``dlrt`` attribute, so the trainer does not need to know
    the structure of the wrapped model.
    """

    def __init__(
            self,
            torch_model: nn.Module,
            rank_percent: float = None,
            adaptive: bool = True,
            tau: dict = None,
            ddp_dlrt_layers: bool = False,
            dense_first_layer: bool = True,
            dense_last_layer: bool = True,
            tucker=True,
            matrix_dlrt=False,
            baseline=False,
            chain_init=False,
            sequential=True,
            load_weights=True,
    ):
        """
        Args:
            torch_model: network to convert. Its layers are replaced in place
                and the wrapper drops its own reference to it afterwards.
            rank_percent: forwarded to ``Conv2d_tucker`` as
                ``low_rank_percent``; must not exceed 1.
            adaptive: forwarded to ``Conv2d_tucker``.
            tau: per-layer-type tolerance, a dict with keys ``'linear'`` and
                ``'conv2d'``. Defaults to ``{'linear': 0.0, 'conv2d': 0.1}``.
            ddp_dlrt_layers: only honoured when ``torch.distributed`` is
                initialized; otherwise forced to ``False``.
            dense_first_layer: keep the first ``Linear``/``Conv2d`` met during
                traversal as an ordinary dense layer.
            dense_last_layer: if the last module ends up low-rank, put the
                original dense layer back.
            tucker: replace ``Conv2d`` layers with Tucker DLRT layers.
            matrix_dlrt: replace ``Linear`` layers with ``DLRTLinearAdaptive``.
            baseline: use a single plain forward/backward pass in
                :meth:`populate_gradients` instead of the four DLRT steps.
            chain_init: tie consecutive layers' subspaces (:meth:`chain_init`).
            sequential: rebuild the converted model as
                ``nn.Sequential(*model.layer)``.
            load_weights: initialise low-rank layers from the dense weights.

        Raises:
            TypeError: if ``tau`` is given and is not a dict.
            ValueError: if ``rank_percent`` is greater than 1.
        """
        super().__init__()

        if tau is None:
            tau = {"linear": 0.0, "conv2d": 0.1}
        elif not isinstance(tau, dict):
            raise TypeError(
                f"tau must be a dict with a value for every type of DLRT layer "
                f"('linear', 'conv2d'), currently: {tau}",
            )
        if rank_percent and rank_percent > 1:
            raise ValueError(
                f"rank_percent should be at most 1, but got rank_percent={rank_percent}",
            )

        # layer-replacement options
        self.tucker = tucker
        self.matrix_dlrt = matrix_dlrt
        self.load_weights = load_weights
        self.adaptive = adaptive
        self.rank_percent = rank_percent
        self.tau = tau
        self.baseline = baseline
        self.sequential = sequential

        self.torch_model = torch_model
        # [module, name] of the module most recently visited by _replace_layers;
        # after wrapping it is the last module of the traversal.
        self.reset_layers = None

        if dist.is_initialized():
            self.ddp_dlrt_layers = ddp_dlrt_layers
            self.rank = dist.get_rank()
        else:
            self.ddp_dlrt_layers = False
            self.rank = 0

        self.dense_last_layer = dense_last_layer
        self.dense_first_layer = dense_first_layer
        # True until the first Linear/Conv2d has been seen (and kept dense).
        self._dfl_wait = dense_first_layer

        self.wrap_model()
        if chain_init:
            self.chain_init()
        self.lr_modules = [module for module in self.lr_model.modules() if self._is_low_rank(module)]
        self.count_layer = len(self.lr_modules)

    @staticmethod
    def _is_low_rank(module):
        """Return True if ``module`` is one of the DLRT low-rank layer types."""
        return isinstance(module, (Conv2d_tucker_adaptive, Conv2d_tucker_fixed, DLRTLinearAdaptive))

    @staticmethod
    def _module_device_dtype(module):
        """Return ``(device, dtype)`` for ``module``.

        Uses the module's own ``device``/``dtype`` attributes when present,
        otherwise those of its first parameter, otherwise ``None``.
        """
        first = next(module.parameters(), None)
        device = getattr(module, "device", first.device if first is not None else None)
        dtype = getattr(module, "dtype", first.dtype if first is not None else None)
        return device, dtype

    @torch.no_grad()
    def wrap_model(self):
        """Convert ``self.torch_model`` into ``self.lr_model`` and drop the original reference."""
        self.first_layer = None

        self.lr_model = self._replace_layers(self.torch_model)
        del self.torch_model  # no longer needed: lr_model holds the converted network

        if self.dense_last_layer:
            self._reset_last_layers_to_dense2(self.lr_model)

        if self.sequential:
            # the original author notes sequential models break without this flattening
            self.lr_model = torch.nn.Sequential(*self.lr_model.layer)

        if dist.is_initialized():
            # NOTE(review): under torch.distributed every direct child defining
            # reset_parameters is re-initialized, which discards weights copied in
            # by load_weights=True. Confirm this is intended.
            for layer in self.lr_model.children():
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()

    @torch.no_grad()
    def update_K_Q(self, tucker_module, w):
        """Initialise the K and Q factors of a Tucker DLRT layer from the 4-D weight ``w``.

        ``w`` is Tucker-decomposed (``eps=1e-6``, ranks capped at ``w.shape``)
        and ``tucker_module.dynamic_rank`` is overwritten with the ranks found.
        For each mode ``i`` the mode-``i`` unfolding of the core is QR-factorized,
        ``unfold_i(C).T = Q_i R_i``. The layer receives ``K_i = U_i R_i.T`` in
        ``Ks[i][:, :r_i]`` and ``Q_i.T`` folded back to a 4-D tensor in the leading
        block of ``Qst[i]``. Contracting that tensor with ``K_i`` along mode ``i``
        and with the other ``U_j`` along the remaining modes reconstructs ``w``
        up to the rounding tolerance.
        """
        factors = tntorch.round_tucker(tntorch.tensor.Tensor(w), eps=1e-6, rmax=list(w.shape), dim='all',
                                       algorithm='svd')
        tucker_module.dynamic_rank = list(factors.ranks_tucker)
        core = factors.tucker_core()
        # mode-i unfolding of the 4-D core: bring axis i to the front, flatten the rest
        unfold_patterns = (
            'i j k l -> i (j k l)',
            'i j k l -> j (i k l)',
            'i j k l -> k (i j l)',
            'i j k l -> l (i j k)',
        )
        for i in range(len(w.shape)):
            ri = factors.ranks_tucker[i]
            a, b, c = [s for k, s in enumerate(factors.ranks_tucker) if k != i]
            unfolded = rearrange(core, unfold_patterns[i])
            q_i, r_i = torch.linalg.qr(unfolded.T)
            q_tensor = rearrange(q_i.T, 'ri (a b c) -> ri a b c', ri=ri, a=a, b=b, c=c)
            tucker_module.Ks[i][:, :ri] = factors.Us[i] @ r_i.T
            tucker_module.Qst[i][:ri, :a, :b, :c] = q_tensor

    def load_layer_weights(self, full_rank_module, low_rank_module):
        """Copy pretrained dense weights into the matching low-rank layer.

        * ``nn.Conv2d``: a Tucker decomposition truncated at
          ``low_rank_module.rmax`` is written into the leading blocks of ``Us``,
          ``U_hats`` (adaptive layers only) and the core ``C`` (zero elsewhere);
          then K/Q are initialised via :meth:`update_K_Q`.
        * ``nn.Linear``: a truncated SVD is written into ``U``, ``V`` and ``S_hat``.
        * ``nn.BatchNorm2d``: ``weight`` and ``bias`` are shared.
        * anything else: ``load_state_dict`` from the dense module.

        Bias parameters are shared with the dense module, not copied.
        """
        if isinstance(full_rank_module, nn.Conv2d):
            factors = tntorch.round_tucker(tntorch.tensor.Tensor(full_rank_module.weight.data), eps=1e-6,
                                           rmax=low_rank_module.rmax, dim='all', algorithm='svd')
            ranks = [u.shape[1] for u in factors.Us]
            r1, r2, r3, r4 = ranks
            for i in range(len(low_rank_module.dims)):
                r = ranks[i]
                low_rank_module.Us[i].data[:, :r] = factors.Us[i]
                if isinstance(low_rank_module, Conv2d_tucker_adaptive):
                    # augmented basis: first r columns = U_i, next r columns zeroed
                    low_rank_module.U_hats[i].data[:, :r] = factors.Us[i]
                    low_rank_module.U_hats[i].data[:, r:2 * r].zero_()
            # Core = Tucker core in the leading rank block, exactly zero elsewhere.
            # Done in place so C keeps its own device/dtype (a CPU mask breaks on CUDA).
            core = low_rank_module.C.data
            core.zero_()
            core[:r1, :r2, :r3, :r4] = factors.tucker_core()
            low_rank_module.bias = full_rank_module.bias
            low_rank_module.dynamic_rank = [r1, r2, r3, r4]
            # NOTE(review): update_K_Q re-decomposes the *full* weight (rmax=w.shape)
            # and overwrites dynamic_rank, so K/Q may not match the truncated Us/C above.
            self.update_K_Q(low_rank_module, full_rank_module.weight)
        elif isinstance(full_rank_module, nn.Linear):
            rank = low_rank_module.dynamic_rank
            # torch.linalg.svd returns (U, S, Vh) with W = U @ diag(S) @ Vh
            U, S, Vh = torch.linalg.svd(full_rank_module.weight.data)
            U, S, V = U[:, :rank], torch.diag(S[:2 * rank]), Vh.T[:, :rank]
            low_rank_module.U.data = U
            low_rank_module.V.data = V
            # NOTE(review): S_hat gets the top 2*rank singular values while U/V keep
            # `rank` columns - presumably the augmented adaptive core; confirm.
            low_rank_module.S_hat.data = S
            low_rank_module.bias = full_rank_module.bias
        elif isinstance(full_rank_module, torch.nn.BatchNorm2d):
            low_rank_module.bias = full_rank_module.bias
            low_rank_module.weight = full_rank_module.weight
        else:
            low_rank_module.load_state_dict(full_rank_module.state_dict())

    def _replace_layers(self, module, pretrain=False, name=None, process_group=None):
        """Recursively replace ``Linear``/``Conv2d`` layers with DLRT counterparts.

        While ``self._dfl_wait`` is set, the first ``Linear``/``Conv2d`` met is
        kept dense and clears the flag. After that, ``Linear`` layers are replaced
        only when ``matrix_dlrt`` is set and ``Conv2d`` layers only when ``tucker``
        is set. All other modules are kept and their children processed.
        ``pretrain`` and ``process_group`` are accepted for compatibility but unused.

        Returns:
            The module to put in place of ``module``: ``module`` itself, or a new
            DLRT layer that also adopts ``module``'s (converted) children.
        """
        module_output = module

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            if self._dfl_wait:
                # first dense-capable layer: keep it dense and stop waiting
                self._dfl_wait = False
            elif isinstance(module, nn.Linear):
                if self.matrix_dlrt:
                    module_output = DLRTLinearAdaptive(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        tau=self.tau['linear'],
                    ).to(device=module.weight.device, dtype=module.weight.dtype)
                    if self.load_weights:
                        self.load_layer_weights(module, module_output)
                        module_output.K_preprocess_step()
                        module_output.L_preprocess_step()
            elif self.tucker:
                module_output = Conv2d_tucker(
                    adaptive=self.adaptive,
                    low_rank_percent=self.rank_percent,
                    in_channels=module.in_channels,
                    out_channels=module.out_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    dilation=module.dilation,
                    groups=module.groups,
                    bias=module.bias is not None,
                    padding_mode=module.padding_mode,
                    tau=self.tau["conv2d"],
                ).to(device=module.weight.device, dtype=module.weight.dtype)
                if self.load_weights:
                    self.load_layer_weights(module, module_output)
                    module_output.K_preprocess_step()

        # the last module visited is what _reset_last_layers_to_dense2 restores
        self.reset_layers = [module, name]

        for child_name, child in module.named_children():
            module_output.add_module(child_name, self._replace_layers(child, pretrain=pretrain, name=child_name,
                                                                      process_group=process_group))
        return module_output

    def _reset_last_layers_to_dense2(self, model):
        """Put the original dense layer back if the last module of ``model`` is low-rank.

        "Last" is the final entry of ``model.named_modules()``. The replacement is
        ``self.reset_layers[0]``, the last module visited by :meth:`_replace_layers`,
        moved to the replaced layer's device/dtype. Does nothing otherwise.
        """
        last_name, last_module = list(model.named_modules())[-1]
        if not last_name or not self._is_low_rank(last_module):
            # already dense (or the root itself, which cannot be swapped in place)
            return

        device, dtype = self._module_device_dtype(last_module)
        dense = self.reset_layers[0].to(device=device, dtype=dtype)
        if self.sequential:
            model.layer[-1] = dense
        else:
            # resolve dotted names such as "classifier.6" to (parent, attribute)
            parent_name, _, child_name = last_name.rpartition('.')
            setattr(model.get_submodule(parent_name), child_name, dense)

    def _reset_last_layer_to_dense(self, module, name=None):
        """Recursively swap children named ``self.reset_layers[1]`` for ``self.reset_layers[0]``.

        Matching is by child name only. The dense module is moved to the
        device/dtype of the module it replaces, read from its ``weight`` (or
        ``k``) tensor; when neither is available the module is kept unchanged.
        """
        module_output = module
        if name == self.reset_layers[1]:
            try:
                ref = module.weight if hasattr(module, "weight") else module.k
                device, dtype = ref.device, ref.dtype
            except AttributeError:
                device = dtype = None
            if device is not None:
                module_output = self.reset_layers[0].to(device=device, dtype=dtype)
        for child_name, child in module.named_children():
            module_output.add_module(child_name, self._reset_last_layer_to_dense(child, child_name))
        return module_output

    def eval(self):
        """Switch DLRT layers to the 'test' step and the module to evaluation mode."""
        self.set_step('test')
        return self.train(False)

    def __run_command_on_dlrt_layers(self, module, command, kwargs=None):
        """Call method ``command(**kwargs)`` on every module in the tree flagged with ``dlrt``."""
        if kwargs is None:
            kwargs = {}

        # only DLRT layers (marked by a `dlrt` attribute) implement the commands
        if hasattr(module, "dlrt"):
            getattr(module, command)(**kwargs)

        for child in module.children():
            self.__run_command_on_dlrt_layers(child, command, kwargs)

    def train_weights_and_bias_(self, module, bool_flag):
        """Recursively set ``requires_grad`` on 'weight'/'bias' parameters of non-DLRT modules."""
        if not hasattr(module, "dlrt"):
            for n, p in module.named_parameters():
                if 'weight' in n.lower() or 'bias' in n.lower():
                    p.requires_grad = bool_flag

        for child in module.children():
            self.train_weights_and_bias_(child, bool_flag)

    def train_weights_and_bias(self, bool_flag):
        """Set ``requires_grad`` to ``bool_flag`` for weights and biases of standard layers."""
        self.train_weights_and_bias_(self.lr_model, bool_flag)

    def _set_step(self, module, step):
        """Recursively call ``set_step(step)`` on every module flagged with ``dlrt``."""
        if hasattr(module, "dlrt"):
            module.set_step(step)

        for child in module.children():
            self._set_step(child, step)

    def set_step(self, step):
        """Set the DLRT step (e.g. 1-4, 'core', 'test') on every DLRT layer."""
        self._set_step(self.lr_model, step=step)

    def __collect_ranks(self, module):
        """Append ``dynamic_rank`` of every ``dlrt``-flagged module (pre-order) to ``self.ranks``.

        NOTE(review): assumes every DLRT layer exposes ``dynamic_rank`` (as set
        in load_layer_weights / update_K_Q) - confirm for all layer types.
        """
        if hasattr(module, "dlrt"):
            rank = module.dynamic_rank
            # copy lists so callers cannot mutate the layer's own rank state
            self.ranks.append(list(rank) if isinstance(rank, list) else rank)
        for child in module.children():
            self.__collect_ranks(child)

    def get_all_ranks(self):
        """Return the ``dynamic_rank`` of every DLRT layer, in module-tree order."""
        self.ranks = []
        self.__collect_ranks(self.lr_model)
        out_ranks = self.ranks.copy()
        self.ranks = []
        return out_ranks

    def forward(self, input):
        return self.lr_model(input)

    def populate_gradients(self, x, y, criterion, step='all'):
        """Populate the gradients needed by a DLRT optimization step.

        Non-baseline, ``step='all'``: freeze weights/biases of ordinary layers,
        run four forward passes with the DLRT layers set to steps 1-4, and
        back-propagate the sum of the four losses. Returns
        ``(loss of step 1, detached output of step 4)``.
        Non-baseline, ``step='core'``: unfreeze ordinary layers, set step
        'core' and return the loss (no backward).
        Baseline: one forward pass. With ``'all'`` it back-propagates and returns
        ``(loss, detached output)``; with ``'core'`` it returns the loss.
        Any other ``step`` returns ``None``.
        """
        if not self.baseline:
            if step == 'all':
                self.train_weights_and_bias(False)
                losses = []
                for dlrt_step in (1, 2, 3, 4):
                    self.set_step(step=dlrt_step)
                    output = self.forward(x)
                    losses.append(criterion(output, y))
                sum(losses[1:], losses[0]).backward()
                return losses[0], output.detach()
            if step == 'core':
                self.train_weights_and_bias(True)
                self.set_step(step='core')
                return criterion(self.forward(x), y)
            return None

        self.train_weights_and_bias(True)
        output = self.forward(x)
        loss = criterion(output, y)
        if step == 'all':
            loss.backward()
            return loss, output.detach()
        if step == 'core':
            return loss
        return None

    @torch.no_grad()
    def deactivate_all_grads(self):
        """Set ``requires_grad = False`` on every parameter of the wrapped model."""
        for p in self.lr_model.parameters():
            p.requires_grad = False

    def chain_init(self):
        """Experimental init that ties each DLRT layer's ``Us[1]`` to the previous layer's ``Us[0]``.

        Adaptive layers (``fixed`` false) also tie ``U_hats[1]`` to the previous
        ``U_hats[0]``. Assumes ``self.lr_model`` is iterable (sequential model).
        """
        output_subspace = None

        for layer in self.lr_model:
            if not (hasattr(layer, 'dlrt') and layer.dlrt):
                continue
            if not layer.fixed:
                if output_subspace is not None:
                    layer.Us[1] = output_subspace[0][0]
                    layer.U_hats[1] = output_subspace[1][0]
                output_subspace = layer.Us, layer.U_hats
            else:
                if output_subspace is not None:
                    layer.Us[1] = output_subspace[0]
                output_subspace = layer.Us

    @torch.no_grad()
    def schedule_tau(self, factor):
        """Multiply the ``tau`` of every low-rank layer by ``factor``."""
        for layer in self.lr_modules:
            layer.tau = factor * layer.tau

    ############ OPTIMIZATION STEPS FOR DLRT
    # Each method broadcasts the same-named command to every DLRT layer.

    def K_preprocess_step(self):
        self.__run_command_on_dlrt_layers(self.lr_model, 'K_preprocess_step')

    def L_preprocess_step(self):
        self.__run_command_on_dlrt_layers(self.lr_model, 'L_preprocess_step')

    def S_preprocess_step(self):
        self.__run_command_on_dlrt_layers(self.lr_model, 'S_preprocess_step')

    def K_postprocess_step(self):
        self.__run_command_on_dlrt_layers(self.lr_model, 'K_postprocess_step')

    def L_postprocess_step(self):
        self.__run_command_on_dlrt_layers(self.lr_model, 'L_postprocess_step')

    def S_postprocess_step(self):
        self.__run_command_on_dlrt_layers(self.lr_model, 'S_postprocess_step')

    def update_Q(self):
        self.__run_command_on_dlrt_layers(self.lr_model, 'update_Q')
